Fix/Feat (trunc avg pool): Update truncation and average pool behaviour #1042
# Proposal: Changes to `TruncQuant` and `TruncAvgPool`

Given that the behaviour of `TruncQuant` and `TruncAvgPool` has existed for many years, and some examples do rely on it, I think it's worth opening this up as a discussion with the community. Note, as far as I can tell, the only places `TruncQuant` is used are with `TruncAvgPool` and `TruncQuantAccumulator`, but I think the issues can be understood through the lens of `TruncAvgPool`.

## Motivation
I find the current implementation of `TruncQuant` and `TruncAvgPool` to be a little troubling, and incorrect outside of a specific use-case. I'll explain further down in this document.

## Current Implementation / Status
Before explaining my issues with the current implementation, let me first explain how `TruncQuant` and `TruncAvgPool` currently work.

### Current `TruncQuant` Implementation

`TruncQuant` shifts the integer representation of the input left or right (i.e., multiplies or divides by a power of two) proportionally to the difference between the input and output bit widths. That new integer value is then reinterpreted using the same scale as the input.
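As a minimal sketch of this arithmetic in plain PyTorch (the helper below is illustrative, assuming `scale=1`, `zero_point=0`, and round-to-nearest):

```python
import torch

def trunc_quant_value(int_value: torch.Tensor,
                      input_bit_width: int,
                      output_bit_width: int) -> torch.Tensor:
    # Shift right by the bit-width difference (i.e. divide by a power of
    # two) and round; the scale is left untouched, so the result is simply
    # reinterpreted at the narrower bit width.
    shift = input_bit_width - output_bit_width
    return torch.round(int_value / 2 ** shift)

print(trunc_quant_value(torch.tensor([127.]), 8, 4))  # tensor([8.])
print(trunc_quant_value(torch.tensor([255.]), 8, 4))  # tensor([16.]) - does not fit in 4 bits
```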
I.e., if two's complement is used, the binary representation of the underlying data has changed from `0b01111111` to `0b1000`, but since the scale factor remains the same, the interpretation of these values changes from `127` to `8`. Effectively, the operation divides the `value` of the input by `2**(input_bit_width - output_bit_width)`, while also rounding the result. Also note, the most significant bits of the input 8-bit type are always kept (possibly with rounding).

Finally, there is no guard against overflow: if instead `x = torch.tensor([255.])` in the above example (the second case in the sketch above), the output is an invalid `IntQuantTensor`, since the value `16` cannot be represented in 4 bits (if `scale=1`, `zero_point=0`).
### Current `TruncAvgPool` Implementation

This needs to be split into two sections, since one part is the behaviour of `torch.nn.functional.avg_pool2d` with an `IntQuantTensor` input; afterwards, the output of this functional call is modified and passed into a `TruncQuant` module.

#### `IntQuantTensor` and `torch.nn.functional`
When an `IntQuantTensor` is passed to `torch.nn.functional.avg_pool2d`, the `scale` and `zero_point` fields are effectively ignored. The `bit_width` field is updated to match the total number of bits required to represent the sum which underlies the average operation. The `value` field of the `IntQuantTensor` is simply passed to `torch.nn.functional.avg_pool2d` and passed on to the result.
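For example, a reconstruction of the kind of snippet described (the exact setup is my assumption; `QuantIdentity` produces an 8-bit `IntQuantTensor` by default):

```python
import torch
import torch.nn.functional as F
from brevitas.nn import QuantIdentity

torch.manual_seed(0)
quant_inp = QuantIdentity(return_quant_tensor=True)
qx = quant_inp(torch.rand(1, 1, 3, 3))  # a valid 8-bit IntQuantTensor

out = F.avg_pool2d(qx, kernel_size=3)
# scale / zero_point: propagated unchanged
# bit_width: 8 + ceil(log2(3 * 3)) = 12, enough to hold the 9-element sum
# value: the plain float average, generally NOT an integer multiple of scale
print(out)
```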
Note that the `value` field of the output is no longer an integer multiple of the scale, meaning that this operation does not produce a valid `IntQuantTensor`.

#### Scaling and Passing to `TruncQuant`

Before the output of the `torch.nn.functional.avg_pool2d` call is passed to `TruncQuant`, the `value` field of the intermediate `IntQuantTensor` is multiplied by the value of the denominator in the average calculation, effectively turning the intermediate result from an `AvgPool` into a `SumPool`. This step corrects the mismatch between the `value` and the `scale` of the intermediate output and converts it into a valid `IntQuantTensor` again.

This intermediate result is then passed to `TruncQuant`, and the truncation / reinterpretation process described in the previous section is performed. Overall, when the input and output bit widths of a `TruncAvgPool` operation are the same, something somewhat sane occurs.
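For example, a minimal reconstruction using `brevitas.nn` (the constructor arguments are my assumption, not the original snippet):

```python
import torch
from brevitas.nn import QuantIdentity, TruncAvgPool2d

torch.manual_seed(0)
quant_inp = QuantIdentity(return_quant_tensor=True)  # 8-bit input quantizer
pool = TruncAvgPool2d(kernel_size=3, bit_width=8)    # output bit width == input bit width
print(pool(quant_inp(torch.rand(1, 1, 3, 3))))
```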
Note that, since the same random seed is used as in the previous example, the output here, `IntQuantTensor(value=tensor([[[[0.2661]]]]), scale=0.00700347451493144, zero_point=0.0, bit_width=8.0, ...`, is directly comparable to the output of the previous example, `IntQuantTensor(value=tensor([[[[0.4747]]]], ..., scale=0.00700347451493144, zero_point=0.0, bit_width=12.0, ...`.

Note that the `bit_width` has reduced (12 to 8) as expected, while the scale factor remains the same in both. The dequantized value has scaled approximately proportionally to `k**2 / 16`, where `k**2` is the denominator of the average that occurred in the average pool, and `16` is `2**(input_bit_width - output_bit_width)` as described in the `TruncQuant` section (with `k = 3` here: `0.4747 * 9 / 16 ≈ 0.267`, matching the `0.2661` above up to rounding). An argument can be made that this output is a desired one, even though the dequantized value significantly differs from the expected (unquantized) one.

However, note that the input and output bit widths of the `TruncAvgPool` above are identical. If the bit width of `TruncAvgPool` is instead set to 12, no truncation occurs (the intermediate bit width is also 12), and a side effect of the behaviour described above is that the average pool has been converted into a sum pool. Conversely, if the bit width of `TruncAvgPool` is set to 4, the dequantized output is approximately `k**2 / 256` times the expected (unquantized) value, since the `TruncQuant` module applies a much larger dividing factor. Both cases are sketched below.
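Concretely, under the same assumed setup as the sketches above, the two variants would look like this:

```python
import torch
from brevitas.nn import QuantIdentity, TruncAvgPool2d

torch.manual_seed(0)
quant_inp = QuantIdentity(return_quant_tensor=True)
qx = quant_inp(torch.rand(1, 1, 3, 3))

# bit_width=12: the intermediate sum is also 12-bit, so TruncQuant divides by
# 2**(12 - 12) = 1 and the output is the raw 9-element sum -- a SumPool.
print(TruncAvgPool2d(kernel_size=3, bit_width=12)(qx))

# bit_width=4: TruncQuant divides by 2**(12 - 4) = 256, so the dequantized
# output is roughly k**2 / 256 = 9/256 of the unquantized average.
print(TruncAvgPool2d(kernel_size=3, bit_width=4)(qx))
```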
## Takeaways
When `input_bitwidth == output_bitwidth` in a `TruncAvgPool` layer, we get somewhat sane functionality, but outside of this I consider the behaviour at best "unintuitive" and at worst "buggy". Also, even the sane behaviour for `input_bitwidth == output_bitwidth` seems to break several rules implicit in the Brevitas codebase, specifically:

* the rule of `QuantTensor` that a `QuantTensor` should have valid data
* the expectation that quantized layers match their floating-point equivalents unless the `divisor_override` parameter is used

## Proposal
In order to correct these issues, I propose the following:

* Fix the `AvgPool` functional call so that a valid `IntQuantTensor` is produced.
* The `TruncQuant` operator should not reinterpret its output; instead, its scale should be adjusted by the amount of truncation that has occurred (see the sketch below).
* Add a guard to `TruncQuant` to avoid overflow / underflow.
* `TruncAvgPool` should not manipulate `IntQuantTensor`s itself - this should be handled by the functional call and `TruncQuant`.
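To make the second and third points concrete, here is a hypothetical sketch of the proposed `TruncQuant` semantics (unsigned, `zero_point=0`; the function name and signature are mine, not the existing API):

```python
import torch

def proposed_trunc_quant(int_value: torch.Tensor, scale: float,
                         input_bit_width: int, output_bit_width: int):
    shift = input_bit_width - output_bit_width
    new_int = torch.round(int_value / 2 ** shift)
    # Guard against overflow by clamping to the representable range.
    new_int = torch.clamp(new_int, 0, 2 ** output_bit_width - 1)
    # Fold the truncation into the scale instead of reinterpreting the
    # integer under the old scale, (approximately) preserving the
    # dequantized value.
    new_scale = scale * 2 ** shift
    return new_int, new_scale

# 255 at 8 bits -> integer 15 at 4 bits with scale 16: the dequantized value
# goes 255.0 -> 240.0, instead of the current 255.0 -> 16.0.
print(proposed_trunc_quant(torch.tensor([255.]), 1.0, 8, 4))
```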
The current behaviour can be recovered in the case `input_bitwidth == output_bitwidth` when the `divisor_override` parameter is set to `2**math.ceil(math.log2(k*k))`. Other scenarios require careful manipulation of `divisor_override` to be achieved.
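A quick plain-PyTorch check of that equivalence (ignoring quantization itself; the numbers are illustrative):

```python
import math

import torch
import torch.nn as nn

k = 3
x = torch.rand(1, 1, k, k)

# With input_bitwidth == output_bitwidth, the current TruncAvgPool effectively
# computes sum / 2**ceil(log2(k*k)), i.e. an AvgPool2d with a power-of-two
# divisor_override.
divisor = 2 ** math.ceil(math.log2(k * k))  # 16 for k = 3
pool = nn.AvgPool2d(k, divisor_override=divisor)

print(torch.allclose(pool(x), x.sum() / divisor))  # True
```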
Furthermore, I question the usefulness in hardware of the current `TruncAvgPool`, for the following reasons:

So finally, I recommend:

* Fix the `AvgPool` functional call.